library(tidyverse)
library(fst)
library(lubridate)
library(data.table)
library(tictoc)

# Remove_Deceased: drop each consumer's rows from the first flagged quarter on.
#
# A row is flagged (dead_marker = 1) when VARNAME != "N". Within each cid the
# rows are put in qtr order and the flag is carried forward (LOCF). Once a cid
# is flagged, every later quarter of that cid is flagged too. Only rows that
# are still unflagged are returned. A row whose VARNAME is NA is not flagged
# by itself; it can only inherit an earlier flag.
# NOTE(review): VARNAME is a redacted column name. It is presumably a
# deceased indicator where "N" means "not deceased"; confirm.
#
# dataset: data.frame / data.table with columns cid, qtr and VARNAME.
# Returns: a new data.table (the input is copied, not modified) sorted by
#   cid, qtr. It keeps the helper column dead_marker, which is NA in every
#   returned row.
Remove_Deceased <- function(dataset) {

  DT <- data.table(dataset)

  # Both branches are double so fifelse() never has to reconcile types.
  DT[, dead_marker := fifelse(VARNAME == "N", NA_real_, 1)]
  DT <- DT[order(cid, qtr)]
  # Carry a flag forward to all later quarters of the same consumer.
  DT[, dead_marker := nafill(dead_marker, type = "locf"), by = cid]
  DT[is.na(dead_marker)]

}

# Assign_Consumer_Group: add a character column consumer_category.
#
# Categories are tested in order and the first match wins:
#   "CC": the first five-term sum is 0,
#   "CD": the two-term sum is > 0 and the following three-term sum is 0,
#   "CF": the last three-term sum is > 0,
#   NA otherwise, including when any operand is NA (fcase treats NA as FALSE).
# NOTE(review): the VARNAME operands are redacted column names; the business
# meaning of CC/CD/CF cannot be confirmed from this file.
#
# dataset: data.frame / data.table containing the referenced columns.
# Returns: a new data.table (the input is copied) with consumer_category
#   added. It is returned explicitly, like the sibling helpers.
Assign_Consumer_Group <- function(dataset) {

  DT <- data.table(dataset)

  DT[, consumer_category := fcase(
    VARNAME + VARNAME + VARNAME + VARNAME + VARNAME == 0, "CC",
    (VARNAME + VARNAME > 0) & (VARNAME + VARNAME + VARNAME == 0), "CD",
    VARNAME + VARNAME + VARNAME > 0, "CF",
    default = NA_character_
  )]

  DT

}

# Create_Outcome_Variable: build the forward-looking outcome t_default.
#
# Steps:
#   1. p_total = sum of three (redacted) VARNAME columns for each row.
#   2. Each cid's history is expanded to a gap-free grid of quarters (steps of
#      3 months between its first and last qtr). Missing quarters get NA.
#   3. For each row, p_total is replaced by the sum over the NEXT 8 quarters
#      (rows t+1 .. t+8). It is NA if any of those quarters is missing/NA or
#      the window runs past the end of the cid's history (na.rm = FALSE).
#   4. t_default = 1 if that forward sum is > 0, 0 if it is 0, NA if NA.
#   5. Only the (cid, qtr) rows present in the input are returned, in input
#      row order.
#
# dataset: data.frame / data.table with cid, qtr (a Date-like column usable
#   with seq(by = "3 months")) and the referenced VARNAME columns.
# Returns: a data.table with p_total (overwritten with the forward sum) and
#   t_default added.
Create_Outcome_Variable <- function(dataset) {

  DT <- data.table(dataset)
  DT_keys <- DT[, c("cid", "qtr")]

  DT[, p_total := VARNAME + VARNAME + VARNAME]

  # Complete quarterly grid per consumer, so a rolling window of 8 rows spans
  # exactly 8 calendar quarters.
  DT_keys_no_missing_qtr <- DT[
    , .(qtr = seq(min(qtr), max(qtr), "3 months")),
    by = cid
  ]

  # Join columns are spelled out instead of relying on a key set earlier
  # surviving later subsets.
  DT <- DT[DT_keys_no_missing_qtr, on = c("cid", "qtr")]
  DT <- DT[order(cid, qtr)]

  # Left-aligned window = rows t .. t+7; leading by one gives t+1 .. t+8.
  DT[, p_total := shift(
    frollsum(p_total, n = 8, na.rm = FALSE, align = "left"),
    n = 1, type = "lead"
  ), by = cid]
  DT[, t_default := fifelse(p_total > 0, 1, 0)]

  # Drop the filler quarters again: keep only the original (cid, qtr) rows.
  DT[DT_keys, on = c("cid", "qtr"), nomatch = NULL]
}

path_cluster <- "/jbod/projects/Credit_Scoring_ML_MersaultMoultonSantucci/"

path_ccp <- paste0(path_cluster, "data_qtr_rand_no_cid_before_consumer_group_var/")

# Only .fst files are inputs; other files in the folder would otherwise get
# NA qtr / rand_no_cid values.
v_file_names <- list.files(path_ccp, pattern = "\\.fst$")

# Index of input files. The quarter is the first 8-digit run (YYYYMMDD), and
# the bucket id is the 4 digits right before ".fst".
# NOTE(review): assumes input names follow the same ccp_<YYYYMMDD>_<NNNN>.fst
# pattern that this script writes for its outputs; confirm.
df_file_names <- tibble(
  file_name = v_file_names
) %>%
  mutate(
    qtr = ymd(str_extract(file_name, "[0-9]{8}")),
    rand_no_cid = as.numeric(str_extract(file_name, "[0-9]{4}(?=\\.fst$)"))
  ) %>%
  glimpse()

# Main processing loop ----
# For each random cid bucket:
#   - read all of the bucket's quarterly input files,
#   - drop deceased consumers,
#   - build t_default and consumer_category,
#   - write one output .fst file per quarter.
# `v_rand_no_cid` was used here but never defined in this script. Define it
# beforehand to process a subset of buckets; otherwise every bucket listed in
# df_file_names is processed.
if (!exists("v_rand_no_cid")) {
  v_rand_no_cid <- sort(unique(df_file_names$rand_no_cid))
}

# path_out <- paste0(path, "intermediate/data_qtr_rand_no_cid_with_outcome/")
path_out_cluster <- paste0(path_cluster, "data_qtr_rand_no_cid_with_outcome/")

# if(!dir.exists(path_out)){ dir.create(path_out, recursive = TRUE)}
if (!dir.exists(path_out_cluster)) {
  dir.create(path_out_cluster, recursive = TRUE)
}

for (rand_no_cid_ in v_rand_no_cid) {

  tic()
  v_file_names_rand_no_cid <- df_file_names %>%
    filter(rand_no_cid == rand_no_cid_) %>%
    pull(file_name)

  # paste0() would turn an empty file list into the bare directory path.
  if (length(v_file_names_rand_no_cid) == 0) {
    message("No input files for rand_no_cid ", rand_no_cid_, "; skipping.")
    toc()
    next
  }

  v_path_file_names <- paste0(path_ccp, v_file_names_rand_no_cid)

  # ~100 seconds
  tic()
  df_raw <- v_path_file_names %>%
    map_df(~read_fst(.))
  toc()

  # set.seed(303)
  # v_cid_sample <- df_raw %>%
  #   select(cid) %>%
  #   distinct() %>%
  #   pull(cid) %>%
  #   sample(25000)

  tic()
  df <- df_raw %>%
    #filter(cid %in% v_cid_sample) %>%
    Remove_Deceased() %>%
    Create_Outcome_Variable() %>%
    Assign_Consumer_Group() %>%
    filter(!is.na(t_default)) %>%
    as_tibble()
  toc()

  v_qtr <- sort(unique(df$qtr))
  rand_no_cid_string <- str_pad(rand_no_cid_, 4, "left", pad = "0")

  # Index over v_qtr (not `for (qtr_ in v_qtr)`) so qtr_ keeps its Date class;
  # seq_along() also makes an empty result write nothing.
  for (i_qtr in seq_along(v_qtr)) {
    qtr_ <- v_qtr[i_qtr]

    df_qtr <- df %>%
      filter(qtr == qtr_)

    file_name_out <- paste0(
      path_out_cluster, "ccp_", format(qtr_, "%Y%m%d"), "_",
      rand_no_cid_string, ".fst"
    )
    write_fst(df_qtr, file_name_out)
  }

  print(rand_no_cid_)
  toc()
}



# Scratch check: full paths of the raw input files for bucket rand_no_cid 85.
# Built from path_ccp rather than a working-directory-relative folder, so it
# matches the paths the main loop reads.
files <- df_file_names %>%
  filter(rand_no_cid == 85) %>%
  pull(file_name) %>%
  paste0(path_ccp, .)

# Read_Files: read every .fst path in `file_name`.
#
# Each path is printed just before it is read, as a progress trace.
# Returns: a list with one data frame per path, in input order. Names are
#   carried over from `file_name` if it has any.
Read_Files <- function(file_name) {

  read_one <- function(path) {
    print(path)
    read_fst(path)
  }

  map(file_name, read_one)

}

Read_Files(files)

# Safe_Read_Files was called here but never defined in this script. Unless the
# session already provides one, wrap Read_Files with purrr::safely(). A read
# failure is then captured in test$error instead of stopping the script.
if (!exists("Safe_Read_Files")) {
  Safe_Read_Files <- safely(Read_Files)
}
test <- Safe_Read_Files(files)

# NOTE(review): df_training is not created anywhere in this script; it must
# already exist in the session (presumably from a downstream training step).
# Confirm its source.
df_training$t_default %>% table()



# Sanity checks on the written output: read every per-quarter outcome file
# back, then look at the overall default rate and the quarter coverage.
# The folder is built from path_cluster, so the check does not depend on the
# working directory. Only .fst files are listed, and qtr / rand_no_cid are
# parsed from the file name alone, not from the full path.
files <- tibble(
  file_names = fs::dir_ls(
    paste0(path_cluster, "data_qtr_rand_no_cid_with_outcome/"),
    regexp = "\\.fst$"
  ),
  qtr = ymd(str_extract(basename(file_names), "[0-9]{8}")),
  rand_no_cid = as.numeric(
    str_extract(basename(file_names), "[0-9]{4}(?=\\.fst$)")
  )
)


data <- files %>%
  # filter(rand_no_cid %in% 0) %>%
  pull(file_names) %>%
  map_dfr(., function(x) {
    print(x)
    read_fst(x)
  })

# t_default has no NAs here: the main loop filtered them out before writing.
mean(data$t_default)

table(data$qtr, useNA = "always")
sum(is.infinite(data$qtr))

sum(is.na(data$qtr))

